import time
from itertools import repeat, chain, islice

import numpy as np
import torch
import torch.nn as nn
from torch.utils import data
from collections import defaultdict
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
import wandb
import os

from knowledge_tracing.args import ARGS
from knowledge_tracing.network.util_network import NoamOpt


class NoamOptimizer:
    '''
    Adam optimizer whose updates are routed through a ``NoamOpt`` wrapper.

    ``NoamOpt`` (imported from ``knowledge_tracing.network.util_network``) is
    built with ``factor=1`` and the given ``model_size``/``warmup``; the exact
    learning-rate schedule it applies is defined there, not in this class.
    '''

    def __init__(self, model, lr, model_size, warmup):
        '''
        Args:
            model: module whose ``parameters()`` are optimized.
            lr: learning rate handed to ``torch.optim.Adam``.
            model_size: model dimension forwarded to ``NoamOpt``.
            warmup: warmup step count forwarded to ``NoamOpt``.
        '''
        adam = torch.optim.Adam(model.parameters(), lr=lr)
        self._adam = adam
        self._opt = NoamOpt(model_size=model_size, factor=1,
                            warmup=warmup, optimizer=adam)

    def step(self, loss):
        '''Clear gradients, backpropagate ``loss`` and apply one update.'''
        scheduled = self._opt
        scheduled.zero_grad()
        loss.backward()
        scheduled.step()


class Trainer:
    '''
    Train, validate, checkpoint and test a knowledge-tracing model.

    ``train()`` streams training batches, validates every ``ARGS.eval_steps``
    steps, saves a checkpoint per validation and keeps only the one with the
    best validation AUC on disk. ``test()`` then evaluates that checkpoint on
    the test set.

    Batches are dicts keyed by augmentation name (``'ori'`` plus, during
    training, the names in ``ARGS.augmentations``); each value is a dict of
    tensors.
    '''

    def __init__(self, model, device, warm_up_step_count,
                 d_model, use_wandb, num_epochs, num_steps, weight_path, learning_rate,
                 train_data, val_data, test_data, run_i=1, early_stop_patience=50):
        '''
        Args:
            model: network called as ``model(batch)``; must return
                ``(outputs_by_aug, additional_losses_or_None)``.
            device: torch device the model and batch tensors are moved to.
            warm_up_step_count: warmup steps for the Noam optimizer.
            d_model: model size for the Noam optimizer.
            use_wandb: whether metrics are logged with ``wandb.log``.
            num_epochs: number of passes over ``train_data``.
            num_steps: if > 0, total training steps; otherwise derived from
                ``num_epochs``.
            weight_path: prefix for checkpoint files, written as
                ``f'{weight_path}{step}.pt'``. It is also passed to
                ``os.listdir`` for cleanup, so presumably a directory path
                ending in a separator.
            learning_rate: learning rate given to Adam.
            train_data, val_data, test_data: datasets wrapped in DataLoaders.
            run_i: run index appended to logged metric names.
            early_stop_patience: stop training once this many validations
                have scored below the best AUC since the last improvement.
        '''
        self._device = device
        self._use_wandb = use_wandb
        self._weight_path = weight_path
        self._num_epochs = num_epochs
        self._num_steps = num_steps
        self._early_stop_patience = early_stop_patience

        self._train_data = train_data
        self._val_data = val_data
        self._test_data = test_data

        self._model = model
        # per-element loss so padding can be masked out in _get_loss
        self._loss_fn = nn.BCELoss(reduction='none')
        self._model.to(device)

        # weight applied to each auxiliary loss the model returns, by key
        self.additional_losses_weight = {
            'rep': ARGS.rep_weight,
            'ins': ARGS.ins_weight,
            'del': ARGS.del_weight,
            'lap': ARGS.lap_weight
        }

        # whether the KT loss is also computed on each augmented view
        self.aug_kt_loss = {
            'rep': ARGS.rep_kt_loss,
            'ins': ARGS.ins_kt_loss,
            'del': ARGS.del_kt_loss
        }

        self._opt = NoamOptimizer(model=model, lr=learning_rate, model_size=d_model, warmup=warm_up_step_count)

        self.step = 0
        self.threshold = 0.5
        self.max_acc = 0
        self.max_auc = 0
        self.max_step = 0

        self.test_acc = 0
        self.test_auc = 0

        self.run_i = run_i

        self._early_stop_cnt = 0

        self._is_train = True

    def train(self):
        '''
        Main training loop.

        Trains in chunks of at most ``ARGS.eval_steps`` batches. After each
        chunk it saves a checkpoint, validates, and deletes every checkpoint
        except the best one. It stops early after ``early_stop_patience``
        non-improving validations.
        '''
        train_gen = data.DataLoader(
            dataset=self._train_data, shuffle=True,
            batch_size=ARGS.train_batch, num_workers=ARGS.num_workers)
        val_gen = data.DataLoader(
            dataset=self._val_data, shuffle=False,
            batch_size=ARGS.test_batch, num_workers=ARGS.num_workers)

        # will train self._num_epochs copies of train data
        to_train = chain.from_iterable(repeat(train_gen, self._num_epochs))
        # consisting of total_steps batches
        if self._num_steps > 0:
            total_steps = self._num_steps
        else:
            total_steps = len(train_gen) * self._num_epochs
        print(f"total_steps: {total_steps}")  # for monitoring

        # train & validation
        self.step = 0
        while self.step < total_steps:
            rem_steps = total_steps - self.step
            num_steps = min(rem_steps, ARGS.eval_steps)
            self.step += num_steps

            # take num_steps batches from to_train stream
            train_batches = islice(to_train, num_steps)
            self._train(train_batches, num_steps)

            cur_weight = self._model.state_dict()
            torch.save(cur_weight, f'{self._weight_path}{self.step}.pt')
            self._test(val_gen, 'Valid')
            print(f'Current best weight: {self.max_step}.pt, best auc: {self.max_auc:.4f}')
            self._remove_stale_weights()
            # early stopping
            if self._early_stop_cnt >= self._early_stop_patience:
                print("early stop!")
                break

    def _remove_stale_weights(self):
        '''
        Delete every ``<int>.pt`` file in ``self._weight_path`` except
        ``{self.max_step}.pt``. Files not named that way are left untouched.
        '''
        for name in os.listdir(self._weight_path):
            if not name.endswith('.pt'):
                continue
            try:
                saved_step = int(name[:-3])
            except ValueError:
                continue  # not a checkpoint written by train(); leave it alone
            if saved_step != self.max_step:
                os.unlink(f'{self._weight_path}{name}')

    def test(self, weight_num=0):
        '''
        Load a checkpoint and evaluate it on the test set.

        Args:
            weight_num: step number of the checkpoint to load. It is ignored
                once training has recorded a best step (``self.max_step != 0``).
        '''
        test_gen = data.DataLoader(
            dataset=self._test_data, shuffle=False,
            batch_size=ARGS.test_batch, num_workers=ARGS.num_workers)

        # load best weight
        if self.max_step != 0:
            weight_num = self.max_step
        # read from the same prefix train() writes checkpoints to
        weight_path = f'{self._weight_path}{weight_num}.pt'
        print(f'best weight: {weight_path}')
        # map_location lets a checkpoint saved on another device (e.g. GPU)
        # load onto the device this trainer runs on
        self._model.load_state_dict(torch.load(weight_path, map_location=self._device))
        self._test(test_gen, 'Test')

    def _forward(self, batch):
        """
        Move ``batch`` to the device, run the model and threshold its outputs.

        ``batch`` maps augmentation names to dicts of tensors. It is updated in
        place so that it holds the device-resident tensors.

        All the outputs are dictionary of tensors.
        Keys are augmentation methods: ori/ins/del/rep

        Returns:
            (label, output, pred, loss_mask, additional_losses):
            - labels are ``2 - is_correct``;
            - ``output`` has its trailing singleton dimension squeezed;
            - ``pred`` is ``(output >= self.threshold).long()``;
            - ``additional_losses`` is the model's second return value
              (may be None).
        """
        label = {}
        loss_mask = {}
        for aug in batch:
            batch[aug] = {k: t.to(self._device) for k, t in batch[aug].items()}

        label['ori'] = 2 - batch['ori']['is_correct']
        loss_mask['ori'] = batch['ori']['loss_mask']
        # augmented views are only labelled/masked while training
        if self._is_train:
            for aug in ARGS.augmentations:
                label[aug] = 2 - batch[aug]['is_correct']
                loss_mask[aug] = batch[aug]['loss_mask']

        output, additional_losses = self._model(batch)
        pred = {}  # dictionary of predicted responses, Long type tensor
        for aug in output:
            output[aug] = output[aug].squeeze(-1)
            pred[aug] = (output[aug] >= self.threshold).long()
        return label, output, pred, loss_mask, additional_losses

    def _get_loss(self, label, output, loss_mask):
        '''
        Masked BCE loss, averaged per row and then over the batch.

        Args:
            label: integer labels. Positions equal to 2 (presumably padding)
                are excluded from each row's denominator.
            output: predicted probabilities, same size as ``label``.
            loss_mask: bool tensor; positions where it is False contribute 0.

        Returns:
            Scalar tensor. A row with no non-2 labels contributes 0.
        '''
        loss = self._loss_fn(output, label.float())
        # Mask loss from padding
        loss = loss.masked_fill(~loss_mask, 0)
        length = (label != 2).sum(-1).type_as(loss)
        # Divide by the clamped length: masking an x/0 quotient afterwards fixes
        # the forward value, but backward would still compute 0/0 and emit NaN
        # gradients. With clamp(min=1) the forward result is unchanged.
        loss_by_batch = loss.sum(-1) / length.clamp(min=1)
        loss_by_batch = loss_by_batch.masked_fill(length == 0, 0)
        return loss_by_batch.mean()

    # takes iterator
    def _train(self, batch_iter, num_batches):
        '''
        Run one optimization step per batch drawn from ``batch_iter``.

        The optimized loss is the KT loss on 'ori', plus the KT loss on each
        augmentation enabled in ``self.aug_kt_loss``, plus the weighted
        auxiliary losses when the model returns them.

        Args:
            batch_iter: Iterator containing batches of training data.
            num_batches: int, number of steps to train for. Used solely
                for tqdm.
        '''
        start_time = time.time()

        self._model.train()
        self._is_train = True

        losses = []
        kt_losses = []
        additional_losses_dict = defaultdict(list)
        num_corrects = 0
        num_total = 0

        for batch in tqdm(batch_iter, total=num_batches):
            label, out, pred, loss_mask, additional_losses = self._forward(batch)
            kt_loss = self._get_loss(label['ori'], out['ori'], loss_mask['ori'])
            num_corrects += (pred['ori'] == label['ori']).masked_fill(~loss_mask['ori'], 0).sum().item()
            num_total += loss_mask['ori'].long().sum().item()
            for aug in ARGS.augmentations:
                if self.aug_kt_loss[aug]:
                    # include KT loss for augmented examples also
                    kt_loss_aug = self._get_loss(label[aug], out[aug], loss_mask[aug])
                    kt_loss += kt_loss_aug
            if additional_losses is not None:
                kt_losses.append(kt_loss.item())
                for k in additional_losses:
                    additional_losses_dict[k].append(additional_losses[k].item())
                additional_losses_sum = sum(additional_losses[k] * self.additional_losses_weight[k] for k in additional_losses)
                train_loss = kt_loss + additional_losses_sum
            else:
                train_loss = kt_loss
            losses.append(train_loss.item())
            self._opt.step(train_loss)

        # the chunk can be empty, e.g. when num_steps asks for more batches
        # than num_epochs passes over the data provide
        acc = num_corrects / num_total if num_total else 0.0
        loss = np.mean(losses)
        training_time = time.time() - start_time

        print(f'correct: {num_corrects}, total: {num_total}')
        print(f'[Train]     time: {training_time:.2f}, loss: {loss:.4f}, acc: {acc:.4f}')

        if self._use_wandb:
            wandb.log({f'Train acc {self.run_i}': acc, f'Train loss {self.run_i}': loss},
                      step=self.step)
            if len(additional_losses_dict) > 0:
                # kt_losses is filled on exactly the batches that fill
                # additional_losses_dict, so it is non-empty here
                wandb.log({f'KT loss {self.run_i}': np.mean(kt_losses)},
                          step=self.step)
                for k in additional_losses_dict:
                    wandb.log({f'{k} loss {self.run_i}': np.mean(additional_losses_dict[k])},
                              step=self.step)

    # takes iterable
    def _test(self, batches, status):
        '''
        Evaluate on ``batches``, then record and log the metrics.

        The loss uses every masked position (plus weighted auxiliary losses
        when the model returns them). Accuracy and AUC use only the position
        ``sequence_size - 1`` of each sequence.

        Args:
            batches: Iterable containing batches of test data.
            status: 'Valid' or 'Test'. 'Valid' updates the best step/AUC and
                the early-stopping counter; 'Test' stores test_acc/test_auc.
        '''
        start_time = time.time()

        self._model.eval()
        self._is_train = False

        losses = []
        num_corrects = 0
        num_total = 0
        labels = []
        outs = []

        with torch.no_grad():
            for batch in tqdm(batches):
                label, out, pred, loss_mask, additional_losses = self._forward(batch)
                test_loss = self._get_loss(label['ori'], out['ori'], loss_mask['ori'])
                # _forward already moved batch tensors to self._device
                seq_size = batch['ori']['sequence_size'].to(self._device)
                # NOTE(review): gather needs seq_size to have the same rank as
                # the per-position tensors, presumably (batch, 1) — confirm
                loss_mask = loss_mask['ori'].gather(-1, seq_size - 1)
                label = label['ori'].gather(-1, seq_size - 1)
                out = out['ori'].gather(-1, seq_size - 1)
                pred = pred['ori'].gather(-1, seq_size - 1)
                if additional_losses is not None:
                    additional_losses_sum = sum(additional_losses[k] * self.additional_losses_weight[k] for k in additional_losses)
                    test_loss += additional_losses_sum
                losses.append(test_loss.item())

                num_corrects += (pred == label).masked_fill(~loss_mask, 0).sum().item()
                num_total += loss_mask.long().sum().item()

                label = label.squeeze(-1).data.cpu().numpy()
                out = out.squeeze(-1).data.cpu().numpy()
                loss_mask = loss_mask.squeeze(-1).data.cpu().numpy()

                # keep only masked-in positions for AUC
                label = np.extract(loss_mask, label)
                out = np.extract(loss_mask, out)

                labels.extend(label)
                outs.extend(out)

        acc = num_corrects / num_total
        auc = roc_auc_score(labels, outs)
        loss = np.mean(losses)
        training_time = time.time() - start_time

        print(f'correct: {num_corrects}, total: {num_total}')
        print(f'[{status}]      time: {training_time:.2f}, loss: {loss:.4f}, acc: {acc:.4f}, auc: {auc:.4f}')

        if status == 'Valid':
            if auc > self.max_auc:
                # update best step
                self.max_step = self.step
                self._early_stop_cnt = 0
            self.max_acc = max(self.max_acc, acc)
            self.max_auc = max(self.max_auc, auc)

            # for early stopping (ties with the best AUC are not counted)
            if auc < self.max_auc:
                self._early_stop_cnt += 1

        elif status == 'Test':
            self.test_acc = acc
            self.test_auc = auc

        if self._use_wandb:
            if status == 'Valid':
                wandb.log({
                    f'Best Valid acc {self.run_i}': self.max_acc,
                    f'Best Valid auc {self.run_i}': self.max_auc
                }, step=self.step)
            wandb.log({
                f'{status} acc {self.run_i}': acc,
                f'{status} auc {self.run_i}': auc,
                f'{status} loss {self.run_i}': loss
            }, step=self.step)

